import pandas as pd
import numpy as np
import torch
from transformers import pipeline, logging
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score
from sklearn.utils import resample
from datasets import load_dataset, DatasetDict, Dataset
from tqdm import tqdm
import random
import warnings
# Silence UserWarning / FutureWarning output for the whole script.
for _category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_category)
del _category

# Opt in to pandas' future "no silent downcasting" behaviour (relevant to the
# .replace() calls that turn label strings into integers) and switch off the
# chained-assignment (SettingWithCopy) check.
pd.set_option('future.no_silent_downcasting', True)
pd.set_option('mode.chained_assignment', None)

# `logging` here is the transformers logging module: only report errors.
logging.set_verbosity_error()

###########################
## Set device from CLI args (--device) or the DEVICE env var
###########################
import os
import argparse
def get_device(argv=None):
    """Return the requested compute backend name, e.g. "cpu", "mps" or "cuda".

    Resolution order: the ``--device`` command-line option, then the
    ``DEVICE`` environment variable, then ``"cuda"``. Empty or
    whitespace-only values are treated as "not set". Availability of the
    backend is NOT checked here.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``. Unknown arguments are ignored, so the script
            still accepts any other flags.

    Returns:
        The chosen backend name, stripped and lower-cased ("CUDA" -> "cuda").
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default=None, help="Device to use: cpu, mps, cuda")
    # parse_known_args: ignore unknown args so the script still works normally
    args, _ = parser.parse_known_args(argv)
    # Priority: CLI argument > DEVICE environment variable > default "cuda"
    for candidate in (args.device, os.getenv("DEVICE")):
        if candidate and candidate.strip():
            return candidate.strip().lower()
    return "cuda"
DEVICE = get_device().strip().lower()  # accept e.g. "CUDA" or " mps "
print(f"[INFO] Using DEVICE={DEVICE}")

# PyTorch setup (torch is already imported at the top of the file).
# A requested accelerator that is unavailable on this machine falls back to
# CPU; say so explicitly instead of switching silently.
if DEVICE == "cpu":
    accelerator_ok = True
elif DEVICE == "mps":
    accelerator_ok = torch.backends.mps.is_available()
elif DEVICE == "cuda":
    accelerator_ok = torch.cuda.is_available()
else:
    raise ValueError(f"Unknown device {DEVICE}")

if accelerator_ok:
    device = torch.device(DEVICE)
else:
    print(f"[WARN] {DEVICE} requested but not available; falling back to cpu")
    device = torch.device("cpu")
print(f"[INFO] Torch device set to {device}")

###########################
## Data, functions, etc.
###########################

def label_docs(model, docs_dict, batch_size=8, device=device,
               max_length=512, torch_dtype=torch.bfloat16):
    """
    Classify inputs with a Hugging Face text-classification pipeline.

    Parameters
    ----------
    model : str
        Model identifier or local path handed to ``transformers.pipeline``.
    docs_dict : list of dict
        Pipeline inputs. In this script each item is
        ``{'text': premise, 'text_pair': hypothesis}``.
    batch_size : int
        Number of inputs per forward pass.
    device : torch.device
        Device the pipeline runs on (defaults to the module-level ``device``).
    max_length : int
        Token limit; longer inputs are truncated (``truncation=True``).
    torch_dtype : torch.dtype
        Dtype the model weights are loaded in (bfloat16 by default).

    Returns
    -------
    list of str
        The top predicted label string for each input. The strings come from
        the model's config; this script expects 'entailment' / 'not_entailment'.
    """
    pipe = pipeline(task='text-classification', model=model,
                    batch_size=batch_size, device=device,
                    max_length=max_length, truncation=True,
                    torch_dtype=torch_dtype)
    return [result['label'] for result in pipe(docs_dict)]

#######################
## Out of Domain Tasks
#######################
ood = pd.read_csv('./data/out_domain_bench.csv').rename(
    columns={'task_name': 'task', 'labels': 'entailment'})
# One premise/hypothesis pair per row, in the dict form the
# text-classification pipeline accepts for sentence pairs.
docs_dict = [{'text': premise, 'text_pair': hypothesis}
             for premise, hypothesis in zip(ood['text'], ood['hypothesis'])]

# models that will be tested, each paired with the column that will hold its
# predictions. Keeping the pairs together stops the two lists drifting out of
# alignment when they are zipped in the loop below.
model_columns = [
    ("MoritzLaurer/deberta-v3-base-zeroshot-v2.0", 'base_nli'),
    ("MoritzLaurer/deberta-v3-large-zeroshot-v2.0", 'large_nli'),
    ("mlburnham/Political_DEBATE_DeBERTa_base_v1.1", 'base_debate'),
    # large DeBERTa checkpoint (this slot used to repeat the base model ID,
    # so 'large_debate' silently duplicated 'base_debate')
    ("mlburnham/Political_DEBATE_DeBERTa_large_v1.1", 'large_debate'),
    ("mlburnham/Political_DEBATE_ModernBERT_base_v1.0", 'base_modern'),
    ("mlburnham/Political_DEBATE_ModernBERT_large_v1.0", 'large_modern'),
]
models = [model_id for model_id, _ in model_columns]
# column names that will hold results
columns = [col_name for _, col_name in model_columns]
if len(set(models)) != len(models) or len(set(columns)) != len(columns):
    raise ValueError("each model and result column must appear exactly once")

# Run every model over the same inputs; store each model's predictions as a
# column coded 'entailment' -> 0, 'not_entailment' -> 1.
label_codes = {'entailment': 0, 'not_entailment': 1}
for col_name, model_id in zip(columns, models):
    ood[col_name] = label_docs(model_id, docs_dict, batch_size=8, device=device)
    ood[col_name] = ood[col_name].replace(label_codes)
    print(f"{model_id} complete.")

# NOTE: writes back to the same CSV that was read, adding the result columns.
ood.to_csv('./data/out_domain_bench.csv', index=False)

################
## NLI Datasets
################
nli = pd.read_csv('./data/nli_bench.csv').rename(
    columns={'task_name': 'task', 'labels': 'entailment'})

# One premise/hypothesis pair per row, in the dict form the
# text-classification pipeline accepts for sentence pairs.
docs_dict = [{'text': premise, 'text_pair': hypothesis}
             for premise, hypothesis in zip(nli['text'], nli['hypothesis'])]

# models that will be tested, each paired with the column that will hold its
# predictions. Keeping the pairs together stops the two lists drifting out of
# alignment when they are zipped in the loop below.
model_columns = [
    ("MoritzLaurer/deberta-v3-base-zeroshot-v2.0", 'base_nli'),
    ("MoritzLaurer/deberta-v3-large-zeroshot-v2.0", 'large_nli'),
    ("mlburnham/Political_DEBATE_DeBERTa_base_v1.1", 'base_debate'),
    # large DeBERTa checkpoint (this slot used to repeat the base model ID,
    # so 'large_debate' silently duplicated 'base_debate')
    ("mlburnham/Political_DEBATE_DeBERTa_large_v1.1", 'large_debate'),
    ("mlburnham/Political_DEBATE_ModernBERT_base_v1.0", 'base_modern'),
    ("mlburnham/Political_DEBATE_ModernBERT_large_v1.0", 'large_modern'),
]
models = [model_id for model_id, _ in model_columns]
# column names that will hold results
columns = [col_name for _, col_name in model_columns]
if len(set(models)) != len(models) or len(set(columns)) != len(columns):
    raise ValueError("each model and result column must appear exactly once")

# Run every model over the same inputs; store each model's predictions as a
# column coded 'entailment' -> 0, 'not_entailment' -> 1.
label_codes = {'entailment': 0, 'not_entailment': 1}
for col_name, model_id in zip(columns, models):
    nli[col_name] = label_docs(model_id, docs_dict, batch_size=8, device=device)
    nli[col_name] = nli[col_name].replace(label_codes)
    print(f"{model_id} complete.")

# NOTE: writes back to the same CSV that was read, adding the result columns.
nli.to_csv('./data/nli_bench.csv', index=False)